import random

from robot.robot_env import Map, MapObject, MapAgent, MovingAgent, CollisionManager, PathPlanner, TerraNodes
import matplotlib.patches as patches
import matplotlib.pyplot as plt

from slam import SlamMap, SlamMapTerraNode
from walle_event import SimpleEvent
from walle_priority_queue import WallePriorityQueue
from walle_world import WalleWorld


class SlamRobot(MovingAgent):
    """Moving agent that carries out cart-moving, charging and sleeping jobs.

    The robot reacts to events through its ``on_*_event`` generator methods and
    answers by yielding new events created with ``self.build_event_instance``
    (presumably ``(created_tick, fire_tick, event_name, message[, target])`` --
    confirm against the base class).

    A moving job is refused while a charging job is active and vice versa; a
    sleeping job is only accepted when neither of the two is active.
    """

    def __init__(self, agent_id, current_node=None, cart=None, battery=1.0):
        MovingAgent.__init__(self, agent_id, current_node)
        self.cart = cart
        self.battery = battery
        self.color = "green"

        # Currently assigned jobs; ``None`` means no job of that kind.
        self.moving_job = None
        self.charging_job = None
        self.sleeping_job = None

        # Delays (in ticks) added to the follow-up move request.
        self.robot_stopping_time = 2
        self.turn_back_time = 3

        # Battery model: amount drained per terra moved, amount gained per
        # charging event, and the level above which a charging job is ended.
        self.battery_consumption_per_terra = 0.001
        self.battery_charging_speed = 0.001
        self.batter_charging_finish = 0.95

    def get_render_patches(self, sq_width, sq_height):
        """Return a filled circle patch centred on the robot's grid cell."""
        center = (self.x * sq_width + 0.5 * sq_width,
                  self.y * sq_height + 0.5 * sq_height)
        radius = sq_height * 0.5
        return patches.Circle(center, radius, fill=True,
                              facecolor=self.color,
                              edgecolor='gray', linewidth=0.5, linestyle="solid")

    def on_moving_cart_job_event(self, tick, environment, agents, event):
        """Accept a cart-moving job, or reject it while a charging job is active.

        On acceptance a "start" state event is emitted and a move towards the
        job's ``from_node`` is requested. If the robot was on a sleeping job,
        that job is stopped first and the current moving plan is cleared.
        """
        if self.charging_job is not None:
            yield self.build_event_instance(tick, tick, "moving_cart_job_failed_event", event.message)
            return

        self.moving_job = event.message
        yield self.build_event_instance(tick, tick, "moving_job_state_event",
                                        {"id": self.moving_job.job_id, "message": "start"})

        if self.sleeping_job is None:
            # Not sleeping: head straight for the pick-up node.
            pickup_request = {"target": self.moving_job.from_node}
            yield self.build_event_instance(tick, tick + 1, "request_move_to_target_event",
                                            pickup_request, self.agent_id)
            return

        # Sleeping job in progress: announce that it stops, then drop it.
        yield self.build_event_instance(tick, tick, "sleeping_job_state_event",
                                        {"message": "stop", "sleeping_node": self.sleeping_job.sleeping_node})
        self.sleeping_job = None
        # todo: stop current moving
        self.moving_plan = []
        pickup_request = {"target": self.moving_job.from_node}
        yield self.build_event_instance(tick, tick + self.robot_stopping_time, "request_move_to_target_event",
                                        pickup_request, self.agent_id)

    def on_charging_job_event(self, tick, environment, agents, event):
        """Accept a charging job, or reject it while a moving job is active.

        On acceptance a move towards the job's ``charging_node`` is requested.
        An active sleeping job is stopped first and the moving plan cleared.
        """
        if self.moving_job is not None:
            yield self.build_event_instance(tick, tick, "charging_job_failed_event", event.message)
            return

        self.charging_job = event.message

        if self.sleeping_job is None:
            charger_request = {"target": self.charging_job.charging_node}
            yield self.build_event_instance(tick, tick + 1, "request_move_to_target_event",
                                            charger_request, self.agent_id)
            return

        # Sleeping job in progress: announce that it stops, then drop it.
        yield self.build_event_instance(tick, tick, "sleeping_job_state_event",
                                        {"message": "stop", "sleeping_node": self.sleeping_job.sleeping_node})
        self.sleeping_job = None
        self.moving_plan = []
        charger_request = {"target": self.charging_job.charging_node}
        yield self.build_event_instance(tick, tick + self.robot_stopping_time, "request_move_to_target_event",
                                        charger_request, self.agent_id)

    def on_sleeping_job_event(self, tick, environment, agents, event):
        """Accept a sleeping job only when no moving or charging job is active."""
        if self.charging_job is not None or self.moving_job is not None:
            return
        self.sleeping_job = event.message
        sleep_request = {"target": self.sleeping_job.sleeping_node}
        yield self.build_event_instance(tick, tick + 1, "request_move_to_target_event",
                                        sleep_request, self.agent_id)

    def on_charging_battery_event(self, tick, environment, agents, event):
        """Perform one charging step and end the charging job once charged enough.

        While the battery is below 1.0 it is increased by
        ``battery_charging_speed`` and the next charging step is scheduled for
        the following tick. When the level exceeds ``batter_charging_finish``
        and a charging job is still set, a "stop" state event is emitted and
        the job is cleared.
        """
        if self.battery < 1.0:
            self.battery += self.battery_charging_speed
            yield self.build_event_instance(tick, tick + 1, "charging_battery_event", self.battery, self.agent_id)

        charged_enough = self.battery > self.batter_charging_finish
        if charged_enough and self.charging_job is not None:
            yield self.build_event_instance(tick, tick, "charging_job_state_event",
                                            {"message": "stop", "charging_node": self.charging_job.charging_node})
            self.charging_job = None

    def on_moved_to_terra_event(self, tick, environment, agents, event):
        """Handle arrival on a new terra node.

        Records the move, releases the previous node, requests the next step of
        the plan, drains the battery and then advances whichever job is active
        based on the node the robot is now standing on.
        """
        arrived_node = event.message["next_node"]
        self.moving_history.append(arrived_node.node_id)

        # Report the node being left before switching to the new one.
        yield self.build_event_instance(tick, tick, "request_leave_terra_event",
                                        {"agent": self, "last_node": self.current_node})
        self.set_current_node(arrived_node)

        if self.moving_target.node_id == self.current_node.node_id:
            yield self.build_event_instance(tick, tick, "moved_to_target_event",
                                            {"agent": self, "next_node": self.current_node})

        upcoming = self.get_next_moving_terra()
        if upcoming is not None:
            yield self.build_event_instance(tick, tick, "request_move_to_terra_event",
                                            {"agent": self, "next_node": upcoming})

        # calculate battery consumption
        self.battery -= self.battery_consumption_per_terra
        yield self.build_event_instance(tick, tick, "robot_battery_event",
                                        {"agent": self, "battery": self.battery})

        if self.moving_job is not None:
            here = self.current_node.node_id
            if here == self.moving_job.from_node.node_id:
                # 2. pick-up node reached: continue to the destination node
                delivery_request = {"target": self.moving_job.to_node}
                yield self.build_event_instance(tick, tick + self.turn_back_time,
                                                "request_move_to_target_event", delivery_request, self.agent_id)
            elif here == self.moving_job.to_node.node_id:
                # 3. destination reached: the moving job is complete
                yield self.build_event_instance(tick, tick + 1, "moving_job_state_event",
                                                {"id": self.moving_job.job_id, "message": "stop"})
                self.moving_job = None

        if self.charging_job is not None \
                and self.current_node.node_id == self.charging_job.charging_node.node_id:
            # charging node reached: start charging
            yield self.build_event_instance(tick, tick + 1, "charging_battery_event", "", self.agent_id)

        if self.sleeping_job is not None \
                and self.current_node.node_id == self.sleeping_job.sleeping_node.node_id:
            # sleeping node reached: the job is cleared before "start" is announced
            self.sleeping_job = None
            yield self.build_event_instance(tick, tick, "sleeping_job_state_event", {"message": "start"})

    def calc_moving_time(self, next_terra):
        """Return the time needed to move onto ``next_terra``.

        Both the loaded (with cart) and unloaded cases currently return 1.
        """
        carrying_cart = self.cart is not None
        return 1 if carrying_cart else 1


class MovingCartJob:
    """A request to move a cart from one node to another.

    Attributes:
        job_id: identifier of the job.
        from_node: node where the cart is picked up.
        to_node: node where the cart is delivered.
        priority: priority given by the caller (defaults to 0).
        agent: agent assigned to the job, or ``None`` if unassigned.
        cart: cart associated with the job, or ``None``.
    """

    def __init__(self, job_id, from_node, to_node, priority=0, agent=None, cart=None):
        self.job_id = job_id
        self.from_node = from_node
        self.to_node = to_node
        # Store the caller's priority; it used to be overwritten with 0.
        self.priority = priority
        self.agent = agent
        self.cart = cart


class ChargingJob:
    """Pairs an agent with the charging node it has been sent to."""

    def __init__(self, charging_node, agent):
        self.charging_node, self.agent = charging_node, agent


class SleepingJob:
    """Pairs an agent with the sleeping node it has been sent to."""

    def __init__(self, sleeping_node, agent):
        self.sleeping_node, self.agent = sleeping_node, agent

class SlamJobManager(MapAgent):
    """Hands out charging, cart-moving and sleeping jobs to ``SlamRobot`` agents.

    State kept by the manager:
        moving_job_queue: priority queue of pending ``MovingCartJob`` objects.
        charging_agent_buffer: ``agent_id -> agent`` for robots that reported a
            battery level below ``battery_min`` and still need a charging job.
        charging_map / sleeping_map: ``node_id -> (node, occupant)`` where the
            occupant is ``None`` when the node is free.
    """

    def __init__(self, agent_id):
        MapAgent.__init__(self, agent_id)
        self.moving_job_queue = WallePriorityQueue()
        self.charging_agent_buffer = {}
        self.charging_map = {}
        self.sleeping_map = {}
        # Delay (in ticks) before re-dispatching after a charging job stops.
        self.schedule_sleeping_time = 5
        # Battery level below which a robot is queued for charging.
        self.battery_min = 0.2

    def on_walle_init_event(self, tick, environment, agents, event):
        """Register every sleeping and charging node of the map as free."""
        map_grid = environment["map"]
        for node in map_grid.sleeping_nodes:
            self.sleeping_map[node.node_id] = (node, None)
        for node in map_grid.charging_nodes:
            self.charging_map[node.node_id] = (node, None)

    def get_empty_sleeping_node(self):
        """Return the first sleeping node without an occupant, or ``None``."""
        for (node, agent) in self.sleeping_map.values():
            if agent is None:
                return node
        return None

    def get_empty_charging_node(self):
        """Return the first charging node without an occupant, or ``None``."""
        for (node, agent) in self.charging_map.values():
            if agent is None:
                return node
        return None

    def on_request_dispatch_job_event(self, tick, environment, agents, event):
        """Assign jobs in three passes: charging, cart moving, then sleeping.

        1. Every buffered low-battery robot that has neither a moving nor a
           charging job gets a ``ChargingJob`` if a charging node is free.
        2. Queued moving jobs are popped while idle robots remain; each job is
           given to the idle robot with the smallest ``get_distance`` to the
           job's ``from_node``.
        3. Remaining idle robots that are not standing on a sleeping or
           charging node are sent to a free sleeping node.
        """
        # charge agent first
        to_charge_agents = set()
        for agent in self.charging_agent_buffer.values():
            if agent.moving_job is None and agent.charging_job is None:
                charging_node = self.get_empty_charging_node()
                if charging_node is not None:
                    # Reserve the node so the next robot gets a different one.
                    self.charging_map[charging_node.node_id] = (charging_node, agent)
                    job = ChargingJob(charging_node, agent)
                    to_charge_agents.add(agent.agent_id)
                    yield self.build_event_instance(tick, tick, "charging_job_event", job, agent.agent_id)
        for agent_id in to_charge_agents:
            del self.charging_agent_buffer[agent_id]

        # for move job
        to_move_agents = set()
        while self.moving_job_queue.size() > 0:
            # get all robots without charging or moving job
            available_robot = set()
            for a in agents:
                agent = a["body"]
                if agent.__class__ is SlamRobot:
                    if agent.charging_job is None and agent.moving_job is None \
                            and agent.agent_id not in to_charge_agents and agent.agent_id not in to_move_agents:
                        available_robot.add(agent)
            if not available_robot:
                # Nobody is free; leave the remaining jobs queued.
                break

            # find nearest robot
            next_job = self.moving_job_queue.pop_task()
            # todo: optimize distance calculation
            # todo: consider battery for job
            # min() always returns one of the candidates, so no sentinel
            # distance is needed and the result can never be None.
            nearest_robot = min(
                available_robot,
                key=lambda robot: SlamPathPlaner.get_distance(
                    robot, next_job.from_node, environment["map"].height))
            to_move_agents.add(nearest_robot.agent_id)
            next_job.agent = nearest_robot
            yield self.build_event_instance(tick, tick, "moving_cart_job_event", next_job, nearest_robot.agent_id)

        # to sleep agent
        # NOTE(review): sleeping_map is never marked as occupied in this class,
        # so every idle robot in this pass is offered the same free node.
        for a in agents:
            agent = a["body"]
            if agent.__class__ is SlamRobot:
                if agent.charging_job is None and agent.moving_job is None \
                        and agent.agent_id not in to_charge_agents and agent.agent_id not in to_move_agents\
                        and agent.current_node.node_type != SlamMapTerraNode.AVG_SLEEPING \
                        and agent.current_node.node_type != SlamMapTerraNode.AVG_CHARGING:
                    sleeping_node = self.get_empty_sleeping_node()
                    if sleeping_node is not None:
                        job = SleepingJob(sleeping_node, agent)
                        yield self.build_event_instance(tick, tick, "sleeping_job_event", job, agent.agent_id)

    def on_moving_job_state_event(self, tick, environment, agents, event):
        """Trigger a new dispatch round one tick after a moving job stops."""
        if event.message["message"] == "stop":
            yield self.build_event_instance(tick, tick + 1,
                                            "request_dispatch_job_event", "")

    def on_charging_job_state_event(self, tick, environment, agents, event):
        """Free the charging node of a stopped charging job and re-dispatch."""
        if event.message["message"] == "stop":
            charging_node = event.message["charging_node"]
            if charging_node is not None:
                self.charging_map[charging_node.node_id] = (charging_node, None)
                yield self.build_event_instance(tick, tick + self.schedule_sleeping_time,
                                                "request_dispatch_job_event", "")

    def on_sleeping_job_state_event(self, tick, environment, agents, event):
        """Mark the sleeping node of a stopped sleeping job as free."""
        if event.message["message"] == "stop":
            sleeping_node = event.message["sleeping_node"]
            if sleeping_node is not None:
                self.sleeping_map[sleeping_node.node_id] = (sleeping_node, None)

    def on_robot_battery_event(self, tick, environment, agents, event):
        """Queue a robot for charging when its battery drops below ``battery_min``.

        Robots that already hold a charging job are skipped. They keep sending
        low-battery reports while driving to the charger; re-buffering them
        would leave a stale entry behind and cause a second charging job to be
        dispatched as soon as the first one finishes.
        """
        battery = event.message["battery"]
        if battery < self.battery_min:
            agent = event.message["agent"]
            if getattr(agent, "charging_job", None) is None:
                self.charging_agent_buffer[agent.agent_id] = agent

    def on_request_move_cart_event(self, tick, environment, agents, event):
        """Queue a new ``MovingCartJob`` and request a dispatch round.

        The message must contain ``from_node_id``, ``to_node_id`` and
        ``job_id``; ``priority`` is optional and defaults to 0.
        """
        map_grid = environment["map"]
        from_node_id = event.message["from_node_id"]
        to_node_id = event.message["to_node_id"]
        job_id = event.message["job_id"]
        priority = event.message["priority"] if "priority" in event.message else 0
        from_node = map_grid.get_node_by_id(from_node_id)
        to_node = map_grid.get_node_by_id(to_node_id)
        job = MovingCartJob(job_id, from_node, to_node, priority)
        self.moving_job_queue.add_task(job, priority)
        yield self.build_event_instance(tick, tick + 1, "request_dispatch_job_event", "")

class SlamCollisionManager(CollisionManager):
    """Grants or denies moves so that each node is owned by one source at a time.

    ``agent_owning_node`` maps ``node_id`` to the ``event.source`` that was last
    approved to move onto that node.
    """

    def __init__(self, agent_id):
        CollisionManager.__init__(self, agent_id)
        self.agent_owning_node = {}

    def on_request_leave_terra_event(self, tick, environment, agents, event):
        """Drop the ownership record of the node that is being left, if any."""
        left_node_id = event.message["last_node"].node_id
        self.agent_owning_node.pop(left_node_id, None)

    def on_request_move_to_terra_event(self, tick, environment, agents, event):
        """Approve the move unless another source already owns the target node.

        An approved move records ``event.source`` as the owner of the node. The
        answer is sent back to ``event.source`` with the original message.
        """
        owners = self.agent_owning_node
        node = event.message["next_node"]
        owned_by_other = node.node_id in owners and event.source != owners[node.node_id]

        if not owned_by_other:
            owners[node.node_id] = event.source
            yield self.build_event_instance(tick, tick, "approved_move_to_terra_event", event.message, event.source)
            return

        yield self.build_event_instance(tick, tick, "forbidden_move_to_terra_event", event.message, event.source)


class SlamPathPlaner(PathPlanner):
    """Path planner that runs an A*-style search over a ``SlamMap``.

    Every step between neighbouring nodes costs 1 and ``get_distance`` is used
    as the heuristic.
    """

    def __init__(self, agent_id, slam_map):
        PathPlanner.__init__(self, agent_id)
        self.slam_map = slam_map
        self.dist_cache = {}  # not read or written by any method of this class
        # Delay (in ticks) between a plan request and the answering event.
        self.path_calc_time = 1

    @staticmethod
    def get_distance(start, end, height, width=0):
        """Return the distance estimate from ``start`` to ``end``.

        Both arguments only need ``x`` and ``y`` attributes. A position counts
        as "left" when ``x < 3``. For two positions on the same side the result
        is ``start.y - end.y`` when ``start.y >= end.y`` and otherwise
        ``2 * height`` minus the y gap; on opposite sides the roles of the two
        cases are swapped. (Presumably this models a one-way loop around the
        map -- confirm against the map layout.) ``width`` is not used.
        """
        start_in_left = (start.x < 3)
        end_in_left = (end.x < 3)
        if start_in_left == end_in_left:
            if start.y >= end.y:
                return start.y - end.y
            else:
                return height * 2 - (end.y - start.y)
        else:
            if start.y <= end.y:
                return end.y - start.y
            else:
                return height * 2 - (start.y - end.y)

    def reconstruct_path(self, cameFrom, current):
        """Follow ``cameFrom`` links back from ``current``.

        Returns the list of nodes beginning with ``current`` and ending with
        the first node whose id has no entry in ``cameFrom``.
        """
        total_path = [current]
        while current.node_id in cameFrom:
            current = cameFrom[current.node_id]
            total_path.append(current)
        return total_path

    def get_min(self, f_scores, open_set):
        """Return the id in ``open_set`` with the lowest f-score.

        Returns ``None`` for an empty ``open_set``. Ties go to the id that is
        iterated first.
        """
        if not open_set:
            return None
        return min(open_set, key=lambda node_id: f_scores[node_id])

    def calc_path(self, start, target):
        """Search a path from ``start`` to ``target``.

        Returns:
            The list of nodes from ``target`` back to ``start`` (callers
            reverse it to get travel order).

        Raises:
            Exception: if the open set runs empty before ``target`` is reached.
        """
        # A* search
        unreached = 1000000
        map_height = self.slam_map.height
        close_set = set()
        open_set = set()
        open_set.add(start.node_id)
        came_from = {}

        # Every node starts as unreached; only the start node has real scores.
        g_scores = {}
        f_scores = {}
        for row in self.slam_map.map:
            for node in row:
                g_scores[node.node_id] = unreached
                f_scores[node.node_id] = unreached
        g_scores[start.node_id] = 0
        f_scores[start.node_id] = SlamPathPlaner.get_distance(start, target, map_height)

        while len(open_set) != 0:
            # get min
            min_f_score_node = self.get_min(f_scores, open_set)
            if min_f_score_node == target.node_id:
                return self.reconstruct_path(came_from, target)

            open_set.remove(min_f_score_node)
            close_set.add(min_f_score_node)

            current_node = self.slam_map.get_node_by_id(min_f_score_node)
            for (x, y) in current_node.get_neighbours():
                neighbour = self.slam_map.get_node_by_xy(x, y)
                if neighbour is None:
                    continue
                if neighbour.node_id in close_set:
                    continue
                open_set.add(neighbour.node_id)

                # todo: add weight
                new_score = g_scores[current_node.node_id] + 1
                if new_score >= g_scores[neighbour.node_id]:
                    continue

                came_from[neighbour.node_id] = current_node
                g_scores[neighbour.node_id] = new_score
                f_scores[neighbour.node_id] = new_score + SlamPathPlaner.get_distance(
                    neighbour, target, map_height)

        raise Exception("cannot find path")

    def on_request_moving_plan_event(self, tick, environment, agents, event):
        """Answer a plan request with a ``moving_plan_event`` for the requester.

        The plan lists the nodes from the agent's current node to the requested
        target. If planning fails, an empty plan is sent instead.
        """
        agent = event.message["agent"]
        target = event.message["moving_target"]
        start = agent.current_node
        # Only the path computation is guarded. Keeping the yield outside the
        # try block (and not using a bare except) means closing this generator
        # is never mistaken for a planning failure.
        try:
            path = self.calc_path(start, target)
            path.reverse()
        except Exception:
            path = []
        yield self.build_event_instance(tick, tick + self.path_calc_time, "moving_plan_event",
                                        {"moving_plan": path}, event.source)


if __name__ == "__main__":
    # --- map -----------------------------------------------------------
    slam_map = SlamMap("slam_map")
    slam_map.terra_width = 2
    slam_map.terra_height = 1
    slam_map.load_map("slam_map.csv")

    # --- robots: robot1..robot5 start on column 1, rows 1..5 -----------
    robots = [SlamRobot("robot" + str(n), slam_map.get_node_by_xy(1, n))
              for n in range(1, 6)]

    walle_world = WalleWorld("slam_world", {"map": slam_map})

    def render(tick, environment, agents):
        """World observer: redraw the map together with the robots."""
        if tick % 1 != 0:
            return
        environment["map"].render(robots=robots)
    walle_world.add_world_observer(render)

    job_counter = 0

    def source(tick, environment, agents):
        """Event source: every 30 ticks request a cart move between random queue nodes."""
        global job_counter
        if tick % 30 != 0:
            return
        job_counter = job_counter + 1
        grid = environment["map"]
        # Draw the pick-up index first, then the drop-off index.
        from_node_i = random.randint(0, len(grid.out_queue_nodes) - 1)
        to_node_i = random.randint(0, len(grid.in_queue_nodes) - 1)
        request = {"job_id": "job-" + str(job_counter),
                   "from_node_id": grid.out_queue_nodes[from_node_i].node_id,
                   "to_node_id": grid.in_queue_nodes[to_node_i].node_id}
        yield SimpleEvent(tick, tick,
                          "request_move_cart_event", "source-1", "job_manager",
                          request)
    walle_world.add_event_source(source)

    # --- agents: managers first, then the robots in order --------------
    collision_manager = SlamCollisionManager("collision_manager")
    path_planner = SlamPathPlaner("path_planner", slam_map)
    job_manager = SlamJobManager("job_manager")
    for world_agent in [collision_manager, path_planner, job_manager] + robots:
        walle_world.add_agent(world_agent)

    # Manual path-planner check, kept disabled:
    # a = path_planner.calc_path(slam_map.get_node_by_xy(1, 13), slam_map.get_node_by_xy(1, 2))
    # print a

    walle_world.run(10000)
    walle_world.dump_event_log()

    # Keep the matplotlib window alive after the simulation has finished.
    while True:
        plt.pause(1.0)